// https://www.lintcode.com/problem/binary-tree-maximum-node/

public class Solution {
    /**
     * Finds the node holding the largest {@code val} in a binary tree.
     *
     * <p>The tree is walked iteratively in pre-order using an explicit stack.
     * A deeply skewed tree therefore cannot overflow the call stack, as a
     * one-frame-per-level recursive walk could.
     *
     * <p>Comparison is strict ({@code >}). When several nodes share the maximum
     * value, the one visited first in pre-order is returned: the node itself,
     * then its left subtree, then its right subtree.
     *
     * @param root the root of the tree; may be {@code null}
     * @return the node with the maximum value, or {@code null} if {@code root} is {@code null}
     */
    public TreeNode maxNode(TreeNode root) {
        if (root == null) {
            return null;
        }

        TreeNode best = root;
        // ArrayDeque rejects null elements, so only non-null children are ever pushed.
        java.util.Deque<TreeNode> pending = new java.util.ArrayDeque<>();
        pending.push(root);

        while (!pending.isEmpty()) {
            TreeNode node = pending.pop();
            // Strict '>' keeps the earliest (pre-order) node on ties.
            if (node.val > best.val) {
                best = node;
            }
            // Push right before left so the whole left subtree is visited first (pre-order).
            if (node.right != null) {
                pending.push(node.right);
            }
            if (node.left != null) {
                pending.push(node.left);
            }
        }
        return best;
    }
}